import torch
from torch.nn import Linear, ReLU
from torch_geometric.nn import MessagePassing
import numpy as np

class DMPNN_front(MessagePassing):
    """First stage of a directed message-passing (D-MPNN) encoder.

    Projects raw node and edge features to ``out_channels`` and refines the edge
    hidden states ``depth`` times with ``message_update``. It then mean-aggregates
    the edge states onto nodes via ``propagate``. ``update`` combines each node's
    embedding with its aggregated edge state.

    ``forward`` returns its state in the same order as ``DMPNN_back.forward``'s
    parameters, so the two stages can be chained.

    Args:
        in_channels: indexable pair; ``in_channels[0]`` is the node-feature size
            and ``in_channels[1]`` is the edge-feature size.
        out_channels: hidden size shared by node and edge states.
        cycle: stored as ``self.cycle``; not read anywhere in this class.
        depth: number of edge-update iterations performed in ``forward``.
    """

    def __init__(self, in_channels, out_channels, cycle=False, depth=4):
        super().__init__(aggr='mean')
        self.depth = depth
        self.cycle = cycle
        # Input projections for edge features (m_*) and node features (a_*).
        self.m_lin_init = Linear(in_channels[1], out_channels, bias=True)
        self.m_lin = Linear(out_channels, out_channels, bias=True)
        self.a_lin_init = Linear(in_channels[0], out_channels, bias=True)
        # Node update consumes [node state || aggregated edge state].
        self.a_lin = Linear(2 * out_channels, out_channels, bias=True)
        self.relu = ReLU()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all four linear layers."""
        self.m_lin_init.reset_parameters()
        self.m_lin.reset_parameters()
        self.a_lin_init.reset_parameters()
        self.a_lin.reset_parameters()

    def forward(self, x, batch, edge_index, edge_index_cycle, edge_attr, rev_edge_index, atom_map_idx=None,
                b_keep_idx=None, rev_edge_index_cycle=None):
        """Encode nodes and edges, run ``depth`` edge updates, then aggregate onto nodes.

        Args:
            x: raw node features with ``in_channels[0]`` columns.
            batch: passed through unchanged.
            edge_index: graph connectivity used for aggregation and propagation.
            edge_index_cycle: accepted for interface compatibility; unused.
            edge_attr: raw edge features with ``in_channels[1]`` columns.
            rev_edge_index: per-edge index of the reverse edge (inferred from the
                name; confirm against the data pipeline).
            atom_map_idx, b_keep_idx, rev_edge_index_cycle: forwarded to
                ``message_update`` and returned unchanged.

        Returns:
            Tuple ``(x, batch, edge_attr, init_edge_attr, edge_index, rev_edge_index,
            atom_map_idx, b_keep_idx, rev_edge_index_cycle)`` where ``x`` are the
            updated node states, ``edge_attr`` the refined edge states and
            ``init_edge_attr`` the pre-activation edge projection.
        """
        init_edge_attr = self.m_lin_init(edge_attr)
        edge_attr = self.relu(init_edge_attr)
        x = self.relu(self.a_lin_init(x))

        # message_update handles both a tensor atom_map_idx and None, so no branching is needed here.
        for _ in range(self.depth):
            edge_attr = self.message_update(edge_attr, init_edge_attr, edge_index, x.size(0), rev_edge_index,
                                            atom_map_idx=atom_map_idx, b_keep_idx=b_keep_idx,
                                            rev_edge_index_cycle=rev_edge_index_cycle)

        num_nodes = x.size(0)
        x = self.propagate(edge_index, edge_attr=edge_attr, x=x, size=(num_nodes, num_nodes))

        return x, batch, edge_attr, init_edge_attr, edge_index, rev_edge_index, atom_map_idx, b_keep_idx, rev_edge_index_cycle

    def message_update(self, edge_attr, init_edge_attr, edge_index, size, rev_edge_index, atom_map_idx=None,
                       b_keep_idx=None, rev_edge_index_cycle=None):
        """Perform one edge-state update.

        Args:
            edge_attr: current edge hidden states, one row per edge.
            init_edge_attr: pre-activation initial edge projection, added back each step.
            edge_index: connectivity; only ``edge_index[1]`` is read here.
            size: number of nodes, i.e. rows of the per-node accumulator.
            rev_edge_index: per-edge index of the reverse edge.
            atom_map_idx: optional tensor. When it is a tensor with positive sum,
                node accumulators are copied from ``atom_map_idx[1][0]`` to
                ``atom_map_idx[0][1]`` and from ``atom_map_idx[0][0]`` to
                ``atom_map_idx[1][1]``.
            b_keep_idx, rev_edge_index_cycle: accepted for interface compatibility; unused.

        Returns:
            ``relu(m_lin(m[edge_index[1]] - edge_attr[rev_edge_index]) + init_edge_attr)``,
            where ``m`` is the per-node sum of edge states keyed by ``edge_index[1]``.
        """
        # Per-node sum of edge states. Allocate directly on edge_attr's device and dtype
        # instead of building a float64 NumPy array on the host and copying it over.
        m = edge_attr.new_zeros((size, edge_attr.shape[1])).index_add_(0, edge_index[1], edge_attr)

        if torch.is_tensor(atom_map_idx) and torch.sum(atom_map_idx) > 0:
            # Statement order matters if the index sets overlap: the second copy
            # reads m after the first has been written.
            m[atom_map_idx[0][1]] = m[atom_map_idx[1][0]]
            m[atom_map_idx[1][1]] = m[atom_map_idx[0][0]]

        # NOTE(review): standard D-MPNN (e.g. chemprop: a_message[b2a] - message[b2revb])
        # aggregates at one endpoint and gathers at the other. Here both the index_add_
        # above and this gather use edge_index[1]. For edge e, m[edge_index[1][e]] therefore
        # contains e itself but generally not its reverse edge, which is what gets subtracted.
        # Confirm against the graph construction whether m[edge_index[0]] was intended;
        # left unchanged so existing checkpoints reproduce.
        m = m[edge_index[1]] - edge_attr[rev_edge_index]
        return self.relu(self.m_lin(m) + init_edge_attr)

    def message(self, edge_attr):
        """Use each edge's hidden state directly as its message."""
        return edge_attr

    def update(self, aggr_out, x):
        """Combine each node's state with its mean-aggregated edge state (no activation)."""
        return self.a_lin(torch.cat([x, aggr_out], dim=1))


class DMPNN_back(MessagePassing):
    """Second stage of a directed message-passing (D-MPNN) encoder.

    Unlike ``DMPNN_front``, this stage has no input projections. It continues
    refining already-hidden edge states for ``depth`` more steps, then
    mean-aggregates them onto nodes via ``propagate``.

    Args:
        in_channels: accepted for interface compatibility; unused.
        out_channels: hidden size of the incoming node and edge states.
        cycle: stored as ``self.cycle``; not read anywhere in this class.
        depth: number of edge-update iterations performed in ``forward``.
    """

    def __init__(self, in_channels, out_channels, cycle=False, depth=4):
        super().__init__(aggr='mean')
        self.depth = depth
        self.cycle = cycle
        self.m_lin = Linear(out_channels, out_channels, bias=True)
        # Node update consumes [node state || aggregated edge state].
        self.a_lin = Linear(2 * out_channels, out_channels, bias=True)
        self.relu = ReLU()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers."""
        self.m_lin.reset_parameters()
        self.a_lin.reset_parameters()

    def forward(self, x, batch, edge_attr, init_edge_attr, edge_index, rev_edge_index, atom_map_idx, b_keep_idx,
                rev_edge_index_cycle):
        """Run ``depth`` edge updates, then aggregate edge states onto nodes.

        The parameter order matches the tuple returned by ``DMPNN_front.forward``.

        Args:
            x: node hidden states with ``out_channels`` columns.
            batch: accepted for interface compatibility; unused.
            edge_attr: edge hidden states with ``out_channels`` columns.
            init_edge_attr: initial edge projection, added as a residual each step.
            edge_index: graph connectivity.
            rev_edge_index: per-edge index of the reverse edge.
            atom_map_idx, b_keep_idx, rev_edge_index_cycle: forwarded to ``message_update``.

        Returns:
            ``(x, edge_attr)``: updated node states and refined edge states.
        """
        # message_update handles both a tensor atom_map_idx and None, so no branching is needed here.
        for _ in range(self.depth):
            edge_attr = self.message_update(edge_attr, init_edge_attr, edge_index, x.size(0), rev_edge_index,
                                            atom_map_idx=atom_map_idx, b_keep_idx=b_keep_idx,
                                            rev_edge_index_cycle=rev_edge_index_cycle)

        num_nodes = x.size(0)
        x = self.propagate(edge_index, edge_attr=edge_attr, x=x, size=(num_nodes, num_nodes))

        return x, edge_attr

    def message_update(self, edge_attr, init_edge_attr, edge_index, size, rev_edge_index, atom_map_idx=None,
                       b_keep_idx=None, rev_edge_index_cycle=None):
        """Perform one edge-state update.

        Args:
            edge_attr: current edge hidden states, one row per edge.
            init_edge_attr: initial edge projection, added back each step.
            edge_index: connectivity; only ``edge_index[1]`` is read here.
            size: number of nodes, i.e. rows of the per-node accumulator.
            rev_edge_index: per-edge index of the reverse edge.
            atom_map_idx: optional tensor. When it is a tensor with positive sum,
                node accumulators are copied from ``atom_map_idx[1][0]`` to
                ``atom_map_idx[0][1]`` and from ``atom_map_idx[0][0]`` to
                ``atom_map_idx[1][1]``.
            b_keep_idx, rev_edge_index_cycle: accepted for interface compatibility; unused.

        Returns:
            ``relu(m_lin(m[edge_index[1]] - edge_attr[rev_edge_index]) + init_edge_attr)``,
            where ``m`` is the per-node sum of edge states keyed by ``edge_index[1]``.
        """
        # Per-node sum of edge states. Allocate directly on edge_attr's device and dtype
        # instead of building a float64 NumPy array on the host and copying it over.
        m = edge_attr.new_zeros((size, edge_attr.shape[1])).index_add_(0, edge_index[1], edge_attr)

        if torch.is_tensor(atom_map_idx) and torch.sum(atom_map_idx) > 0:
            # Statement order matters if the index sets overlap: the second copy
            # reads m after the first has been written.
            m[atom_map_idx[0][1]] = m[atom_map_idx[1][0]]
            m[atom_map_idx[1][1]] = m[atom_map_idx[0][0]]

        # NOTE(review): standard D-MPNN (e.g. chemprop: a_message[b2a] - message[b2revb])
        # aggregates at one endpoint and gathers at the other. Here both the index_add_
        # above and this gather use edge_index[1]. For edge e, m[edge_index[1][e]] therefore
        # contains e itself but generally not its reverse edge, which is what gets subtracted.
        # Confirm against the graph construction whether m[edge_index[0]] was intended;
        # left unchanged so existing checkpoints reproduce.
        m = m[edge_index[1]] - edge_attr[rev_edge_index]
        return self.relu(self.m_lin(m) + init_edge_attr)

    def message(self, edge_attr):
        """Use each edge's hidden state directly as its message."""
        return edge_attr

    def update(self, aggr_out, x):
        """Combine each node's state with its mean-aggregated edge state (no activation)."""
        return self.a_lin(torch.cat([x, aggr_out], dim=1))
